import os.path

import torch
from torch.utils.data import DataLoader
import torch.nn.functional as F
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm

from utils.data.LibriSpeech_set import LibriSpeech_set
from utils.download.ERes2Net import ERes2Net
from utils.metrics import cal_fpr_fnr, cal_eer_threshold, cal_min_dcf_threshold
from utils.models.ECAPA_TDNN import ECAPA_TDNN
from utils.transform.feature_transform import feature_transform

# Checkpoint to evaluate: a results root joined with the run-specific relative path.
_checkpoint_root = 'C:/Users/24667/Documents/毕设数据/'
_checkpoint_relpath = 'ERes2Net_MFCC_AAMLoss/2024_05_04_20_10_38/ERes2Net/best_model/ERes2Net.pth'
model_path = os.path.join(_checkpoint_root, _checkpoint_relpath)

# Feature type handed to feature_transform; presumably selects one of the
# per-feature entries in feature_extraction_cfg below (TODO confirm).
use_feature_extraction = 'MFCC'

# Directory the SummaryWriter logs evaluation metrics into.
log_dir = os.path.join('./logs', 'ERes2Net特征提取对比', 'MFCC')

# Audio preprocessing options forwarded to LibriSpeech_set via dataset_cfg.
# Speed perturbation and noise injection are switched off for evaluation.
_preprocess_cfg = {
    'max_duration': 3,
    'target_sample_rate': 16000,
    'keep_audio_channel': False,
    'use_speed_perturbation': False,
    'speed_perturbation_sequence': (0.9, 0.95, 1, 1.05, 1.1),
    'add_noise': False,
    'max_snr': 50,
    'min_snr': 10,
}
dataset_cfg = {
    'data_path': './dataset/LibriSpeech',
    'preprocess_cfg': _preprocess_cfg,
}

# Keyword arguments forwarded to torch.utils.data.DataLoader.
dataloader_cfg = {'batch_size': 64, 'num_workers': 16, 'pin_memory': True}

# Per-feature-type extractor parameters.
_spectrogram_cfg = {
    'n_fft': 1024,
    'win_length': 400,
    'hop_length': 160,
    'window_fn': torch.hamming_window,
    'normalized': True,
}
_mfcc_cfg = {
    'n_mfcc': 80,
    'log_mels': True,
    'melkwargs': {'n_fft': 400, 'hop_length': 160, 'n_mels': 80},
}
_fbank_cfg = {'num_mel_bins': 80}

# Masking augmentations; all disabled here.
_augmentations_cfg = {
    'time_masking': False,
    'time_mask_param': 5,
    'freq_masking': False,
    'freq_mask_param': 10,
    'iid_masks': True,
}

# Keyword arguments forwarded to feature_transform.
feature_extraction_cfg = {
    'use_mfcc_cms': True,
    'feature_transpose': True,
    'feature_extraction_cfg': {
        'Spectrogram': _spectrogram_cfg,
        'MFCC': _mfcc_cfg,
        'fbank': _fbank_cfg,
    },
    'Augmentations_cfg': _augmentations_cfg,
}

# Constructor arguments per model. There is no 'ERes2Net' entry, so the
# ERes2Net model is constructed with no keyword arguments.
model_cfg = {'ECAPA_TDNN': {'channels': 512, 'bottleneck': 128, 'scale': 8}}


def set_data(_dataset_cfg, _dataloader_cfg, split='test-other'):
    """Build the evaluation dataset and its DataLoader.

    Args:
        _dataset_cfg: Keyword arguments forwarded to ``LibriSpeech_set``.
        _dataloader_cfg: Keyword arguments forwarded to ``DataLoader``.
        split: Subset identifier passed as the first argument of
            ``LibriSpeech_set``; defaults to ``'test-other'``.

    Returns:
        ``(dataset, loader)``. The loader is built with ``shuffle=False`` so
        batches come out in dataset order.
    """
    # Use the configs the caller passed in, not the module-level globals.
    _test_set = LibriSpeech_set(split, **_dataset_cfg)
    _test_loader = DataLoader(_test_set, shuffle=False, **_dataloader_cfg)
    return _test_set, _test_loader


def collect_embeddings(model, transform, loader, device):
    """Embed every utterance yielded by ``loader``.

    Each batch of waveforms is moved to ``device``, converted to features by
    ``transform`` and passed through ``model``. Per-batch outputs are gathered
    in lists and concatenated once at the end. Re-concatenating onto a growing
    tensor every batch would copy all accumulated data each time. Call this
    under ``torch.no_grad()`` when evaluating.

    Args:
        model: Embedding network, called as ``model(features)``.
        transform: Feature extractor, called as ``transform(waveforms)``.
        loader: Iterable of ``(waveforms, labels)`` batches.
        device: Device that waveforms and labels are moved to.

    Returns:
        ``(embeddings, labels)``, each concatenated along dim 0.

    Raises:
        ValueError: If ``loader`` yields no batches.
    """
    embedding_batches = []
    label_batches = []
    for waveforms, labels in tqdm(loader, desc='计算特征'):
        features = transform(waveforms.to(device=device))
        embedding_batches.append(model(features))
        label_batches.append(labels.to(device=device))
    if not embedding_batches:
        raise ValueError('test loader yielded no batches; cannot compute metrics')
    return torch.cat(embedding_batches, dim=0), torch.cat(label_batches, dim=0)


def build_trials(embeddings, labels):
    """Turn per-utterance embeddings into flat verification trials.

    Every unordered pair of *distinct* utterances becomes one trial. Its score
    is the cosine similarity of the two L2-normalised embeddings. It is a
    target trial when both utterances carry the same label. Self-pairs (the
    diagonal) are excluded because they always score 1.0 and are always
    targets, which would bias EER/minDCF. Mirrored pairs (j, i) are dropped
    because they duplicate (i, j) and do not change any rate.

    Args:
        embeddings: 2-D tensor with one embedding per row (``N`` rows).
        labels: 1-D tensor of ``N`` speaker labels aligned with the rows.

    Returns:
        ``(scores, is_target)``: 1-D tensors of length ``N * (N - 1) // 2``.
        ``is_target`` is boolean.
    """
    embeddings = F.normalize(embeddings, dim=-1)
    score_matrix = F.linear(embeddings, embeddings)
    labels_matrix = torch.eq(labels.unsqueeze(0), labels.unsqueeze(1))
    n = labels.shape[0]
    rows, cols = torch.triu_indices(n, n, offset=1, device=score_matrix.device)
    return score_matrix[rows, cols], labels_matrix[rows, cols]


if __name__ == "__main__":
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print('正在准备……')
    test_set, test_loader = set_data(dataset_cfg, dataloader_cfg)
    transform = feature_transform(use_feature_extraction, **feature_extraction_cfg)
    transform = transform.to(device=device)
    # Alternative model (the corresponding checkpoint/model_path must match):
    # model = ECAPA_TDNN(**model_cfg.get('ECAPA_TDNN', {}))
    model = ERes2Net(**model_cfg.get('ERes2Net', {}))
    model = model.to(device=device)
    model.eval()
    # map_location lets a checkpoint saved from a GPU run load on a CPU-only machine.
    model.load_state_dict(torch.load(model_path, map_location=device))
    writer = SummaryWriter(log_dir)
    try:
        with torch.no_grad():
            # NOTE(review): repeated rounds presumably differ through randomness
            # in LibriSpeech_set preprocessing (e.g. cropping to max_duration) — confirm.
            for i in range(5):
                step = i + 1
                print(f'正在进行第{i + 1}轮测试')
                all_identities, all_labels = collect_embeddings(
                    model, transform, test_loader, device)
                print('正在计算模型指标……')
                all_score, trial_labels = build_trials(all_identities, all_labels)
                fpr, fnr, all_score = cal_fpr_fnr(all_score, trial_labels)
                eer, eer_threshold = cal_eer_threshold(fpr, fnr, all_score)
                min_dcf, min_dcf_threshold = cal_min_dcf_threshold(fpr, fnr, all_score)
                print(f'EER: {eer:.4%}    '
                      f'EER threshold: {eer_threshold:.8f}    '
                      f'minDCF: {min_dcf:.8f}    '
                      f'minDCF threshold: {min_dcf_threshold:.8f}')
                metrics = {
                    'EER': eer,
                    'EER_threshold': eer_threshold,
                    'minDCF': min_dcf,
                    'minDCF_threshold': min_dcf_threshold,
                }
                for name, value in metrics.items():
                    writer.add_scalar(tag=f'test/{name}',
                                      scalar_value=value,
                                      global_step=step)
                # Release this round's tensors before the next pass.
                del all_identities, all_labels, all_score, trial_labels, fpr, fnr
                if device.type == 'cuda':
                    torch.cuda.empty_cache()
    finally:
        # Flush pending TensorBoard events to disk even if evaluation fails.
        writer.close()
